{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.ResNetPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNetPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">This is an unofficial PyTorch implementation created by Ignacio Oguiza - oguiza@timeseriesAI.co"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from fastai.layers import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ResBlockPlus(Module):\n",
    "    \"\"\"Residual block: three stacked `ConvBlock`s whose output is added to a shortcut branch.\n",
    "\n",
    "    Args:\n",
    "        ni: number of input channels.\n",
    "        nf: number of output channels of each conv block.\n",
    "        ks: kernel sizes of the three conv blocks, in order.\n",
    "        coord: forwarded to every `ConvBlock`, including the shortcut one.\n",
    "        separable: forwarded to the three main conv blocks.\n",
    "        bn_1st: forwarded to the first two conv blocks.\n",
    "        zero_norm: forwarded to the third conv block only.\n",
    "        sa: if truthy, `SimpleSelfAttention(nf, ks=1)` is applied after the (optional) SE module.\n",
    "        se: reduction passed to `SEModule1d`; only used when `se` is truthy and `nf // se > 0`.\n",
    "        act: activation class used by the first two conv blocks and after the residual addition.\n",
    "        act_kwargs: keyword arguments used to instantiate `act`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ni, nf, ks=[7, 5, 3], coord=False, separable=False, bn_1st=True, zero_norm=False, sa=False, se=None, act=nn.ReLU, act_kwargs={}):\n",
    "        self.convblock1 = ConvBlock(\n",
    "            ni, nf, ks[0], coord=coord, separable=separable, bn_1st=bn_1st, act=act, act_kwargs=act_kwargs)\n",
    "        self.convblock2 = ConvBlock(\n",
    "            nf, nf, ks[1], coord=coord, separable=separable, bn_1st=bn_1st, act=act, act_kwargs=act_kwargs)\n",
    "        # The last conv block has no activation: the activation is applied after the addition.\n",
    "        self.convblock3 = ConvBlock(\n",
    "            nf, nf, ks[2], coord=coord, separable=separable, zero_norm=zero_norm, act=None)\n",
    "        self.se = SEModule1d(\n",
    "            nf, reduction=se, act=act) if se and nf//se > 0 else noop\n",
    "        self.sa = SimpleSelfAttention(nf, ks=1) if sa else noop\n",
    "        # Same channel count: the shortcut is only normalised; otherwise a 1-wide conv block maps ni -> nf.\n",
    "        self.shortcut = BN1d(ni) if ni == nf else ConvBlock(\n",
    "            ni, nf, 1, coord=coord, act=None)\n",
    "        self.add = Add()\n",
    "        self.act = act(**act_kwargs)\n",
    "\n",
    "        self._init_cnn(self)\n",
    "\n",
    "    def _init_cnn(self, m):\n",
    "        \"\"\"Recursively initialise `m` and all of its children.\n",
    "\n",
    "        Every non-None `bias` is set to 0 and Conv1d/Conv2d/Conv3d/Linear weights are\n",
    "        re-drawn with `nn.init.kaiming_normal_`.\n",
    "        \"\"\"\n",
    "        # The checks must inspect the visited module `m`. Testing `self` (always the\n",
    "        # ResBlockPlus instance) made the whole recursive walk a no-op.\n",
    "        if getattr(m, 'bias', None) is not None:\n",
    "            nn.init.constant_(m.bias, 0)\n",
    "        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):\n",
    "            nn.init.kaiming_normal_(m.weight)\n",
    "        for l in m.children():\n",
    "            self._init_cnn(l)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the conv blocks, optional SE / self-attention, the residual addition and the activation.\"\"\"\n",
    "        res = x\n",
    "        x = self.convblock1(x)\n",
    "        x = self.convblock2(x)\n",
    "        x = self.convblock3(x)\n",
    "        x = self.se(x)\n",
    "        x = self.sa(x)\n",
    "        x = self.add(x, self.shortcut(res))\n",
    "        x = self.act(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "@delegates(ResBlockPlus.__init__)\n",
    "class ResNetPlus(nn.Sequential):\n",
    "    \"\"\"Sequential model with a `backbone` of three `ResBlockPlus` blocks followed by a `head`.\n",
    "\n",
    "    Args:\n",
    "        c_in: number of input channels.\n",
    "        c_out: number of outputs of the head.\n",
    "        seq_len: sequence length; required when `flatten=True`, also passed to a callable `custom_head`.\n",
    "        nf: channels of the first block; the second and third blocks output `nf * 2`.\n",
    "        sa: passed to the third block only.\n",
    "        se: passed to the first and second blocks only.\n",
    "        fc_dropout: dropout probability in the default head (no dropout layer when falsy).\n",
    "        concat_pool: use `GACP1d` instead of `GAP1d` in the default head (ignored when `flatten=True`).\n",
    "        flatten: flatten the backbone output instead of pooling it in the default head.\n",
    "        custom_head: an `nn.Module` used as is, or a callable invoked as `custom_head(head_nf, c_out, seq_len)`.\n",
    "        y_range: if given, `SigmoidRange(*y_range)` is appended to the default head.\n",
    "        **kwargs: extra arguments forwarded to every `ResBlockPlus`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in, c_out, seq_len=None, nf=64, sa=False, se=None, fc_dropout=0., concat_pool=False,\n",
    "                 flatten=False, custom_head=None, y_range=None, **kwargs):\n",
    "\n",
    "        resblock1 = ResBlockPlus(c_in,   nf,     se=se,   **kwargs)\n",
    "        resblock2 = ResBlockPlus(nf,     nf * 2, se=se,   **kwargs)\n",
    "        resblock3 = ResBlockPlus(nf * 2, nf * 2, sa=sa, **kwargs)\n",
    "        backbone = nn.Sequential(resblock1, resblock2, resblock3)\n",
    "\n",
    "        # Number of features the head receives: channels, times seq_len when flattening.\n",
    "        self.head_nf = nf * 2\n",
    "        if flatten:\n",
    "            assert seq_len is not None, \"you need to pass seq_len when flatten=True\"\n",
    "            self.head_nf *= seq_len\n",
    "        if custom_head is not None:\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
    "        else:\n",
    "            head = self.create_head(self.head_nf, c_out, flatten=flatten,\n",
    "                                         concat_pool=concat_pool, fc_dropout=fc_dropout, y_range=y_range)\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "\n",
    "    def create_head(self, nf, c_out, flatten=False, concat_pool=False, fc_dropout=0., y_range=None, **kwargs):\n",
    "        \"\"\"Build the default head: (flatten | pooling) -> optional dropout -> linear -> optional range.\n",
    "\n",
    "        When `flatten=True`, `nf` must already be the flattened feature count (`__init__` passes\n",
    "        channels * seq_len); no pooling layer is added and `concat_pool` is ignored.\n",
    "        \"\"\"\n",
    "        if flatten:\n",
    "            # The Flatten layer used to be overwritten by the pooling layer, so the linear layer\n",
    "            # was sized for flattened features but received pooled ones (shape mismatch).\n",
    "            layers = [Flatten()]\n",
    "        else:\n",
    "            if concat_pool:\n",
    "                # The linear layer is sized for twice as many features when GACP1d is used.\n",
    "                nf = nf * 2\n",
    "            layers = [GACP1d(1) if concat_pool else GAP1d(1)]\n",
    "        if fc_dropout:\n",
    "            layers += [nn.Dropout(fc_dropout)]\n",
    "        layers += [nn.Linear(nf, c_out)]\n",
    "        if y_range:\n",
    "            layers += [SigmoidRange(*y_range)]\n",
    "        return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.layers import Swish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output-shape checks on a random batch, for the default and a customised configuration.\n",
    "xb = torch.rand(2, 3, 4)\n",
    "expected_shape = [xb.shape[0], 2]\n",
    "test_eq(ResNetPlus(3, 2)(xb).shape, expected_shape)\n",
    "custom_kwargs = dict(coord=True, separable=True, bn_1st=False, zero_norm=True, act=Swish, act_kwargs={}, fc_dropout=0.5)\n",
    "test_eq(ResNetPlus(3, 2, **custom_kwargs)(xb).shape, expected_shape)\n",
    "# Parameter count of the default model.\n",
    "test_eq(count_parameters(ResNetPlus(3, 2)), 479490) # for (3,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.ResNet import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With default arguments ResNetPlus must have exactly as many parameters as ResNet.\n",
    "n_params_resnet = count_parameters(ResNet(3, 2))\n",
    "n_params_resnetplus = count_parameters(ResNetPlus(3, 2))\n",
    "test_eq(n_params_resnet, n_params_resnetplus) # for (3,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_params: 114820\n",
      "ResNetPlus(\n",
      "  (backbone): Sequential(\n",
      "    (0): ResBlockPlus(\n",
      "      (convblock1): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), groups=4, bias=False)\n",
      "          (pointwise_conv): Conv1d(4, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock2): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(65, 65, kernel_size=(5,), stride=(1,), padding=(2,), groups=65, bias=False)\n",
      "          (pointwise_conv): Conv1d(65, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock3): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(65, 65, kernel_size=(3,), stride=(1,), padding=(1,), groups=65, bias=False)\n",
      "          (pointwise_conv): Conv1d(65, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (shortcut): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): Conv1d(4, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (add): Add\n",
      "      (act): ReLU()\n",
      "    )\n",
      "    (1): ResBlockPlus(\n",
      "      (convblock1): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(65, 65, kernel_size=(7,), stride=(1,), padding=(3,), groups=65, bias=False)\n",
      "          (pointwise_conv): Conv1d(65, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock2): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(129, 129, kernel_size=(5,), stride=(1,), padding=(2,), groups=129, bias=False)\n",
      "          (pointwise_conv): Conv1d(129, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock3): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(129, 129, kernel_size=(3,), stride=(1,), padding=(1,), groups=129, bias=False)\n",
      "          (pointwise_conv): Conv1d(129, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (shortcut): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): Conv1d(65, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (add): Add\n",
      "      (act): ReLU()\n",
      "    )\n",
      "    (2): ResBlockPlus(\n",
      "      (convblock1): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(129, 129, kernel_size=(7,), stride=(1,), padding=(3,), groups=129, bias=False)\n",
      "          (pointwise_conv): Conv1d(129, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock2): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(129, 129, kernel_size=(5,), stride=(1,), padding=(2,), groups=129, bias=False)\n",
      "          (pointwise_conv): Conv1d(129, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (3): ReLU()\n",
      "      )\n",
      "      (convblock3): ConvBlock(\n",
      "        (0): AddCoords1d()\n",
      "        (1): SeparableConv1d(\n",
      "          (depthwise_conv): Conv1d(129, 129, kernel_size=(3,), stride=(1,), padding=(1,), groups=129, bias=False)\n",
      "          (pointwise_conv): Conv1d(129, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
      "        )\n",
      "        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (shortcut): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (add): Add\n",
      "      (act): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (head): Sequential(\n",
      "    (0): GAP1d(\n",
      "      (gap): AdaptiveAvgPool1d(output_size=1)\n",
      "      (flatten): Reshape(bs)\n",
      "    )\n",
      "    (1): Linear(in_features=128, out_features=2, bias=True)\n",
      "  )\n",
      ")\n",
      "[1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 0. 1.]\n"
     ]
    }
   ],
   "source": [
    "# Inspect a model built with zero_norm, coord and separable enabled.\n",
    "m = ResNetPlus(3, 2, zero_norm=True, coord=True, separable=True)\n",
    "n_params = count_parameters(m)\n",
    "print('n_params:', n_params)\n",
    "print(m)\n",
    "# First element returned by check_weight for the layers selected by is_bn.\n",
    "bn_check = check_weight(m, is_bn)[0]\n",
    "print(bn_check)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/034_models.ResNetPlus.ipynb saved at 2023-03-19 14:25:46\n",
      "Correct notebook to script conversion! 😃\n",
      "Sunday 19/03/23 14:25:49 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Export boilerplate: look up this notebook's name, then pass it to create_scripts.\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
